{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# TFLite 模型的保存、解释执行与量化"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2.0.0\n",
      "sys.version_info(major=3, minor=6, micro=10, releaselevel='final', serial=0)\n",
      "matplotlib 3.1.2\n",
      "numpy 1.18.1\n",
      "pandas 0.25.3\n",
      "sklearn 0.22.1\n",
      "tensorflow 2.0.0\n",
      "tensorflow_core.keras 2.2.4-tf\n"
     ]
    }
   ],
   "source": [
    "import matplotlib as mpl\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline\n",
    "import numpy as np\n",
    "import sklearn\n",
    "import pandas as pd\n",
    "import os\n",
    "import sys\n",
    "import time\n",
    "import tensorflow as tf\n",
    "from tensorflow import keras\n",
    "\n",
    "# Report the interpreter and library versions used for this run.\n",
    "print(tf.__version__)\n",
    "print(sys.version_info)\n",
    "for lib in (mpl, np, pd, sklearn, tf, keras):\n",
    "    print(lib.__name__, lib.__version__)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Turn on memory growth for every GPU that TensorFlow reports.\n",
    "gpus = tf.config.experimental.list_physical_devices('GPU')\n",
    "for gpu_device in gpus:\n",
    "    tf.config.experimental.set_memory_growth(gpu_device, True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Directory that receives every converted TFLite model written by this notebook.\n",
    "tflite_dir = 'tflite_models'\n",
    "# makedirs(exist_ok=True) avoids the check-then-create race of exists() + mkdir().\n",
    "os.makedirs(tflite_dir, exist_ok=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Keras模型转TFLite"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: id=492, shape=(1, 10), dtype=float32, numpy=\n",
       "array([[0.11120097, 0.03724404, 0.06243855, 0.01093319, 0.03455226,\n",
       "        0.00780652, 0.1908657 , 0.02144313, 0.49145293, 0.03206269]],\n",
       "      dtype=float32)>"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Load the Keras model stored in the HDF5 file.\n",
    "keras_model_path = './graph_def_and_weights/fashion_mnist_model.h5'\n",
    "loaded_keras_model = keras.models.load_model(keras_model_path)\n",
    "# Call the model on an all-ones batch of shape (1, 28, 28) and display the result.\n",
    "loaded_keras_model(np.ones((1, 28, 28)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build a converter from the in-memory Keras model, then convert it to TFLite.\n",
    "keras_to_tflite_converter = tf.lite.TFLiteConverter.from_keras_model(\n",
    "    loaded_keras_model)\n",
    "keras_tflite = keras_to_tflite_converter.convert()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Write the converted model to a file inside tflite_dir.\n",
    "tflite_file = os.path.join(tflite_dir, 'keras_tflite')\n",
    "with open(tflite_file, mode='wb') as f:\n",
    "    f.write(keras_tflite)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Concrete Function 转 TFLite"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Wrap the Keras model in a tf.function.\n",
    "run_model = tf.function(lambda x: loaded_keras_model(x))\n",
    "# Get a concrete function whose signature uses the shape and dtype of the\n",
    "# model's first input.\n",
    "model_input = loaded_keras_model.inputs[0]\n",
    "keras_concrete_func = run_model.get_concrete_function(\n",
    "    tf.TensorSpec(shape=model_input.shape, dtype=model_input.dtype))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: id=620, shape=(1, 10), dtype=float32, numpy=\n",
       "array([[0.11120097, 0.03724404, 0.06243855, 0.01093319, 0.03455226,\n",
       "        0.00780652, 0.1908657 , 0.02144313, 0.49145293, 0.03206269]],\n",
       "      dtype=float32)>"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Call the concrete function on an all-ones float32 batch and display the result.\n",
    "ones_batch = np.ones((1, 28, 28), dtype=np.float32)\n",
    "keras_concrete_func(tf.constant(ones_batch))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build a converter from the concrete function, then convert it to TFLite.\n",
    "concrete_func_to_tflite_converter = (\n",
    "    tf.lite.TFLiteConverter.from_concrete_functions([keras_concrete_func]))\n",
    "concrete_func_tflite = concrete_func_to_tflite_converter.convert()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Write the converted model to a file inside tflite_dir.\n",
    "tflite_file = os.path.join(tflite_dir, 'concrete_func_tflite')\n",
    "with open(tflite_file, mode='wb') as f:\n",
    "    f.write(concrete_func_tflite)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## SavedModel 转 TFLite"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build a converter from the SavedModel directory, then convert it to TFLite.\n",
    "saved_model_dir = './keras_saved_graph'\n",
    "saved_model_to_tflite_converter = tf.lite.TFLiteConverter.from_saved_model(\n",
    "    saved_model_dir)\n",
    "saved_model_tflite = saved_model_to_tflite_converter.convert()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Write the converted model to a file inside tflite_dir.\n",
    "tflite_file = os.path.join(tflite_dir, 'saved_model_tflite')\n",
    "with open(tflite_file, mode='wb') as f:\n",
    "    f.write(saved_model_tflite)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 使用tflite interpreter 运行TFLite"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read the TFLite file back. The path is built from tflite_dir so it stays in\n",
    "# sync with the cells that wrote the file.\n",
    "with open(os.path.join(tflite_dir, 'concrete_func_tflite'), 'rb') as f:\n",
    "    concrete_func_tflite = f.read()\n",
    "\n",
    "# Create the interpreter from the in-memory model content.\n",
    "interpreter = tf.lite.Interpreter(model_content=concrete_func_tflite)\n",
    "\n",
    "# Allocate tensors before any input is set.\n",
    "interpreter.allocate_tensors()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[{'name': 'x', 'index': 10, 'shape': array([ 1, 28, 28]), 'dtype': <class 'numpy.float32'>, 'quantization': (0.0, 0)}]\n",
      "[{'name': 'Identity', 'index': 0, 'shape': array([ 1, 10]), 'dtype': <class 'numpy.float32'>, 'quantization': (0.0, 0)}]\n"
     ]
    }
   ],
   "source": [
    "# Ask the interpreter for its input and output tensor descriptions.\n",
    "input_details = interpreter.get_input_details()\n",
    "output_details = interpreter.get_output_details()\n",
    "\n",
    "for details in (input_details, output_details):\n",
    "    print(details)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[0.11120094 0.03724403 0.06243857 0.01093319 0.03455226 0.00780652\n",
      "  0.19086568 0.02144315 0.49145293 0.0320627 ]]\n"
     ]
    }
   ],
   "source": [
    "# Define the input: all ones, with the shape and dtype the interpreter reports.\n",
    "# set_tensor() is documented to take a NumPy array, so the array is passed\n",
    "# directly instead of being wrapped in tf.constant.\n",
    "input_data = np.ones(input_details[0]['shape'], dtype=input_details[0]['dtype'])\n",
    "interpreter.set_tensor(input_details[0]['index'], input_data)\n",
    "\n",
    "# Run inference.\n",
    "interpreter.invoke()\n",
    "\n",
    "# Fetch and print the output tensor.\n",
    "output_results = interpreter.get_tensor(output_details[0]['index'])\n",
    "print(output_results)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 量化后保存\n",
    "### keras model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build a fresh converter from the Keras model.\n",
    "keras_to_tflite_converter = tf.lite.TFLiteConverter.from_keras_model(loaded_keras_model)\n",
    "# Enable quantization. Optimize.DEFAULT is used because newer TensorFlow releases\n",
    "# document OPTIMIZE_FOR_SIZE as deprecated and equivalent to DEFAULT.\n",
    "keras_to_tflite_converter.optimizations = [tf.lite.Optimize.DEFAULT]\n",
    "# Convert to TFLite.\n",
    "keras_tflite = keras_to_tflite_converter.convert()\n",
    "\n",
    "# Write the quantized model to its own file.\n",
    "tflite_file = os.path.join(tflite_dir, 'quantized_keras_tflite')\n",
    "with open(tflite_file, 'wb') as f:\n",
    "    f.write(keras_tflite)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### concrete func"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build a fresh converter from the concrete function.\n",
    "concrete_func_to_tflite_converter = tf.lite.TFLiteConverter.from_concrete_functions(\n",
    "    [keras_concrete_func])\n",
    "# Enable quantization. Optimize.DEFAULT is used because newer TensorFlow releases\n",
    "# document OPTIMIZE_FOR_SIZE as deprecated and equivalent to DEFAULT.\n",
    "concrete_func_to_tflite_converter.optimizations = [tf.lite.Optimize.DEFAULT]\n",
    "# Convert to TFLite.\n",
    "concrete_func_tflite = concrete_func_to_tflite_converter.convert()\n",
    "\n",
    "# Write the quantized model to its own file.\n",
    "tflite_file = os.path.join(tflite_dir, 'quantized_concrete_func_tflite')\n",
    "with open(tflite_file, 'wb') as f:\n",
    "    f.write(concrete_func_tflite)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Saved Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build a fresh converter from the SavedModel directory.\n",
    "saved_model_to_tflite_converter = tf.lite.TFLiteConverter.from_saved_model(\n",
    "    './keras_saved_graph')\n",
    "# Enable quantization. Optimize.DEFAULT is used because newer TensorFlow releases\n",
    "# document OPTIMIZE_FOR_SIZE as deprecated and equivalent to DEFAULT.\n",
    "saved_model_to_tflite_converter.optimizations = [tf.lite.Optimize.DEFAULT]\n",
    "# Convert to TFLite.\n",
    "saved_model_tflite = saved_model_to_tflite_converter.convert()\n",
    "\n",
    "# Write the quantized model to its own file.\n",
    "tflite_file = os.path.join(tflite_dir, 'quantized_saved_model_tflite')\n",
    "with open(tflite_file, 'wb') as f:\n",
    "    f.write(saved_model_tflite)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 使用tflite interpreter 运行量化后的TFLite"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read the quantized TFLite file back. The path is built from tflite_dir so it\n",
    "# stays in sync with the cell that wrote the file.\n",
    "with open(os.path.join(tflite_dir, 'quantized_concrete_func_tflite'), 'rb') as f:\n",
    "    concrete_func_tflite = f.read()\n",
    "\n",
    "# Create the interpreter from the in-memory model content.\n",
    "interpreter = tf.lite.Interpreter(model_content=concrete_func_tflite)\n",
    "\n",
    "# Allocate tensors before any input is set.\n",
    "interpreter.allocate_tensors()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[{'name': 'x', 'index': 10, 'shape': array([ 1, 28, 28]), 'dtype': <class 'numpy.float32'>, 'quantization': (0.0, 0)}]\n",
      "[{'name': 'Identity', 'index': 0, 'shape': array([ 1, 10]), 'dtype': <class 'numpy.float32'>, 'quantization': (0.0, 0)}]\n"
     ]
    }
   ],
   "source": [
    "# Ask the interpreter for its input and output tensor descriptions.\n",
    "input_details = interpreter.get_input_details()\n",
    "output_details = interpreter.get_output_details()\n",
    "\n",
    "for details in (input_details, output_details):\n",
    "    print(details)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[0.11255115 0.03726469 0.06337408 0.01115514 0.03467338 0.00785631\n",
      "  0.19097982 0.02149594 0.48883578 0.0318136 ]]\n"
     ]
    }
   ],
   "source": [
    "# Define the input: all ones, with the shape and dtype the interpreter reports.\n",
    "# set_tensor() is documented to take a NumPy array, so the array is passed\n",
    "# directly instead of being wrapped in tf.constant.\n",
    "input_data = np.ones(input_details[0]['shape'], dtype=input_details[0]['dtype'])\n",
    "interpreter.set_tensor(input_details[0]['index'], input_data)\n",
    "\n",
    "# Run inference.\n",
    "interpreter.invoke()\n",
    "\n",
    "# Fetch and print the output tensor.\n",
    "output_results = interpreter.get_tensor(output_details[0]['index'])\n",
    "print(output_results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
